#!/usr/bin/env python3
import argparse
import json
import math
import os
import pickle
import re
import time
import urllib
import urllib.parse
import warnings
from collections import Counter
from functools import lru_cache

import contractions
import requests
from nltk import pos_tag, word_tokenize, ne_chunk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tree import Tree

# === CONFIGURATION ===
# Inclusive (min, max) n-gram lengths scanned by compute_alignment_score.
NGRAM_RANGE        = (2, 5)
# Year window sent to the Ngram API as year_start / year_end.
START_YEAR         = 1950
END_YEAR           = 2022
# Passed straight through as the Ngram API's `smoothing` query parameter.
SMOOTHING          = 0
# Google Books Ngram corpus IDs queried for each variety.
# NOTE(review): confirm both IDs come from the same corpus release and that the
# releases actually cover START_YEAR..END_YEAR; mixing releases would bias the
# US/UK frequency ratio.
CORPORA            = {'US': 17, 'UK': 6}
# An n-gram is skipped when BOTH its US and UK frequencies fall below this.
# Frequencies are never negative, so 0 disables this filter.
MIN_FREQ_THRESHOLD = 0
# Weight multiplier for n-grams containing a word from DIALECTAL_SET.
BOOST_FACTOR       = 1.5
# When True, marker words and n-gram tokens are lemmatized before matching.
USE_LEMMATIZATION  = False

# Shared WordNet lemmatizer; only consulted when USE_LEMMATIZATION is True.
lemmatizer = WordNetLemmatizer()
# NLTK English stopword list; n-grams made up entirely of these are skipped.
STOPWORDS = set(stopwords.words('english'))

# Resolve resource paths relative to this script, not the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Marker file: expected to be a JSON list of objects with "us" and "uk" keys
# (that is how the loader reads it — confirm against the resource file).
marker_path = os.path.join(script_dir, '../#Resources/AmE_BrE_variations.json')

# === Load Dialectal Markers ===
with open(marker_path, 'r', encoding='utf-8') as marker_file:
    DIALECTAL_MARKERS = json.load(marker_file)


def _normalise_marker(word):
    """Lowercase a marker word and lemmatize it when USE_LEMMATIZATION is on."""
    lowered = word.lower()
    if USE_LEMMATIZATION:
        return lemmatizer.lemmatize(lowered)
    return lowered


# Every US and UK marker form, lowercased (and optionally lemmatized), pooled
# into one set for fast token membership checks.
DIALECTAL_SET = {
    _normalise_marker(pair[variant])
    for pair in DIALECTAL_MARKERS
    for variant in ('us', 'uk')
}

# === Persistent Cache ===
os.makedirs(os.path.join(script_dir, 'Cache'), exist_ok=True)
CACHE_FILE = os.path.join(script_dir, './Cache/ngrams_cache.pkl')

# Maps (ngram, corpus, start_year, end_year, smoothing) -> mean frequency.
# SECURITY NOTE: pickle.load can execute arbitrary code. CACHE_FILE must only
# ever hold data written by this script, never a file from an untrusted source.
try:
    with open(CACHE_FILE, 'rb') as f:
        NGRAM_CACHE = pickle.load(f)
except FileNotFoundError:
    NGRAM_CACHE = {}
except (EOFError, pickle.UnpicklingError, AttributeError,
        ImportError, IndexError, ValueError) as exc:
    # A truncated or corrupt cache (e.g. left by an interrupted save) must not
    # make the script unusable: warn and start again with an empty cache.
    warnings.warn(f"Ignoring unreadable n-gram cache {CACHE_FILE!r}: {exc}")
    NGRAM_CACHE = {}

if not isinstance(NGRAM_CACHE, dict):
    # The rest of the script uses NGRAM_CACHE as a dict (get / `in` / item assignment).
    warnings.warn(
        f"Ignoring n-gram cache {CACHE_FILE!r}: expected a dict, "
        f"got {type(NGRAM_CACHE).__name__}"
    )
    NGRAM_CACHE = {}

def save_cache():
    """Persist NGRAM_CACHE to CACHE_FILE with pickle.

    The pickle is first written to a sibling ``.tmp`` file and then moved over
    CACHE_FILE with ``os.replace``. A crash or Ctrl-C in the middle of the
    write therefore leaves the previous cache intact, rather than a truncated
    file that the next run cannot load.
    """
    tmp_path = CACHE_FILE + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(NGRAM_CACHE, f)
        os.replace(tmp_path, CACHE_FILE)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

# === Utility Functions ===

def fix_contractions(text):
    """Return *text* with its English contractions expanded.

    This is a thin wrapper around ``contractions.fix``.
    """
    expanded = contractions.fix(text)
    return expanded

def preprocess_text(text: str, remove_punct: bool=True, *,
                    expand_contractions: bool=False) -> str:
    """Normalise raw text before tokenisation.

    Args:
        text: Raw input text.
        remove_punct: If True, drop every character that is not a word
            character, whitespace, apostrophe or hyphen. Apostrophes and
            hyphens are kept so forms like "don't" and "well-known" survive.
        expand_contractions: If True, expand contractions with
            fix_contractions after lowercasing. Off by default, which matches
            the previous behaviour where this step was disabled.

    Returns:
        The lowercased text with leading and trailing whitespace removed.
    """
    if remove_punct:
        text = re.sub(r"[^\w\s'-]", '', text)
    text = text.lower()
    if expand_contractions:
        text = fix_contractions(text)
    return text.strip()

def get_ngrams(text: str, n: int):
    """Return the space-joined word n-grams of *text*, in order.

    Tokens are runs of word characters, apostrophes and hyphens. No case
    folding is applied. Returns an empty list when ``n < 1`` or when the text
    has fewer than ``n`` tokens.
    """
    if n < 1:
        # Without this guard, n == 0 would produce a list of empty strings.
        return []
    tokens = re.findall(r"\b[\w'-]+\b", text)
    return [' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def is_all_stopwords(ngram: str) -> bool:
    """True when every whitespace-separated word of *ngram* is in STOPWORDS.

    An empty n-gram has no words and so counts as all-stopwords.
    """
    for word in ngram.split():
        if word not in STOPWORDS:
            return False
    return True

@lru_cache(maxsize=100000)
def contains_named_entity(ngram: str) -> bool:
    """Report whether NLTK's NE chunker finds a PERSON, GPE or ORGANIZATION in *ngram*.

    The n-gram is word-tokenized, POS-tagged and passed to ``ne_chunk``. Any
    top-level subtree carrying one of those labels counts as a hit. Results
    are memoised for up to 100000 distinct inputs.
    """
    chunked = ne_chunk(pos_tag(word_tokenize(ngram)))
    wanted_labels = {"PERSON", "GPE", "ORGANIZATION"}
    for node in chunked:
        if isinstance(node, Tree) and node.label() in wanted_labels:
            return True
    return False

def _mean_frequency(entries: list) -> float:
    """Collapse one Ngram JSON response into a single mean relative frequency.

    Each entry's timeseries is reduced to its mean. With
    ``case_insensitive=true`` the response may list one entry per case
    variant, and may also include an aggregate entry that already sums them.
    NOTE(review): the aggregate is recognised by ``type == "CASE_INSENSITIVE"``
    or an ngram label ending in " (All)"; confirm against current API output.
    When an aggregate is present its mean is used. Otherwise the variant means
    are summed, which is their aggregate. Averaging them would make the value
    depend on how many variants each corpus happens to return. Entries
    without a timeseries are ignored; 0.0 is returned if none have one.
    """
    aggregate_means = []
    variant_means = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        ts = entry.get('timeseries') or []
        if not ts:
            continue
        mean = sum(ts) / len(ts)
        label = str(entry.get('ngram', ''))
        if entry.get('type') == 'CASE_INSENSITIVE' or label.endswith(' (All)'):
            aggregate_means.append(mean)
        else:
            variant_means.append(mean)
    return sum(aggregate_means) if aggregate_means else sum(variant_means)


def get_ngram_freq(ngram: str, corpus: int, start_year: int, end_year: int, smoothing: int) -> float:
    """Return the mean relative frequency of *ngram* in one Google Books corpus.

    Results are memoised in NGRAM_CACHE, keyed on all five arguments.
    Case variants are aggregated (see _mean_frequency).

    Up to three requests are made, with exponential backoff (0.5 s, then 1 s)
    between them. Network errors, non-200 statuses and undecodable or
    non-list bodies are retried. A well-formed HTTP 200 list response is
    final: an empty list means the n-gram is absent and yields 0.0 at once.
    If every attempt fails, 0.0 is cached and returned. NGRAM_CACHE is
    persisted by save_cache, so such a zero outlives the current run.
    """
    key = (ngram, corpus, start_year, end_year, smoothing)
    if key in NGRAM_CACHE:
        return NGRAM_CACHE[key]

    q = urllib.parse.quote(ngram)
    url = (
        f'https://books.google.com/ngrams/json?content={q}'
        f'&year_start={start_year}&year_end={end_year}'
        f'&corpus={corpus}&smoothing={smoothing}&case_insensitive=true'
    )

    attempts = 3
    backoff = 0.5
    for attempt in range(attempts):
        try:
            response = requests.get(url, timeout=3)
            if response.status_code == 200:
                payload = response.json()
                if isinstance(payload, list):
                    value = _mean_frequency(payload)
                    NGRAM_CACHE[key] = value
                    return value
        except (requests.RequestException, ValueError, TypeError):
            # Network failure, undecodable JSON or malformed timeseries:
            # deliberately best-effort, fall through to the next attempt.
            pass
        if attempt < attempts - 1:  # no pointless sleep after the last try
            time.sleep(backoff)
            backoff *= 2.0

    # Every attempt failed.
    NGRAM_CACHE[key] = 0.0
    return 0.0

def _candidate_ngrams(text: str) -> Counter:
    """Count the lowercase n-grams of *text* that are worth scoring.

    Punctuation other than apostrophes and hyphens is removed and the text
    is split into word tokens. The original-case tokens are kept alongside
    their lowercased forms. NLTK's ne_chunk uses capitalisation features, so
    named-entity detection runs on the original-case n-gram; running it on
    lowercased text lets most names slip through. The counted keys and all
    cache lookups use the lowercase form.

    An n-gram is skipped when it is made only of stopwords, when its
    original-case form contains a named entity, or when NGRAM_CACHE already
    records 0.0 for it in either corpus (scoring would discard it anyway).
    """
    cleaned = re.sub(r"[^\w\s'-]", '', text)
    cased_tokens = re.findall(r"\b[\w'-]+\b", cleaned)
    tokens = [tok.lower() for tok in cased_tokens]

    counts = Counter()
    for n in range(NGRAM_RANGE[0], NGRAM_RANGE[1] + 1):
        for i in range(len(tokens) - n + 1):
            ng = ' '.join(tokens[i:i + n])
            if is_all_stopwords(ng):
                continue
            if contains_named_entity(' '.join(cased_tokens[i:i + n])):
                continue
            cached_us = NGRAM_CACHE.get((ng, CORPORA['US'], START_YEAR, END_YEAR, SMOOTHING))
            cached_uk = NGRAM_CACHE.get((ng, CORPORA['UK'], START_YEAR, END_YEAR, SMOOTHING))
            # None (not cached yet) never equals 0.0, so unknown n-grams are kept.
            if cached_us == 0.0 or cached_uk == 0.0:
                continue
            counts[ng] += 1
    return counts


def compute_alignment_score(text: str) -> dict:
    """Score how strongly *text* leans towards American or British English.

    For each candidate n-gram (see _candidate_ngrams) the mean frequency is
    fetched from the US and UK corpora. N-grams absent from either corpus
    are dropped. For the rest:

      * lr     = log2(f_us / f_uk)                 (> 0 leans US, <= 0 leans UK)
      * div    = |f_us - f_uk| / (f_us + f_uk)
      * weight = div, multiplied by BOOST_FACTOR if any (optionally
        lemmatized) token is in DIALECTAL_SET
      * contribution = |lr| * weight, added to the US or the UK total

    Returns:
        A dict with "US_Alignment_Score" and "UK_Alignment_Score" (each
        side's share of the total contribution, rounded to 4 places; both
        0.0 when nothing contributed) and "Top_Contributors": up to 15
        tuples (ngram, f_us, f_uk, lr, weight, contribution), sorted by
        contribution in descending order.
    """
    ctr_us, ctr_uk = 0.0, 0.0
    contributors = []

    # NOTE(review): occurrence counts are collected but each distinct n-gram
    # contributes once; confirm repeated phrases should not be weighted.
    for ng in _candidate_ngrams(text):
        f_us = get_ngram_freq(ng, CORPORA['US'], START_YEAR, END_YEAR, SMOOTHING)
        if f_us == 0:
            continue  # absent from the US corpus; skip the UK request entirely
        f_uk = get_ngram_freq(ng, CORPORA['UK'], START_YEAR, END_YEAR, SMOOTHING)
        if f_uk == 0 or (f_us < MIN_FREQ_THRESHOLD and f_uk < MIN_FREQ_THRESHOLD):
            continue

        lr = math.log2(f_us / f_uk)
        div = abs(f_us - f_uk) / (f_us + f_uk)

        tokens_proc = [lemmatizer.lemmatize(tok) if USE_LEMMATIZATION else tok
                       for tok in ng.split()]
        boosted = any(tok in DIALECTAL_SET for tok in tokens_proc)
        weight = div * BOOST_FACTOR if boosted else div

        contrib = abs(lr) * weight
        contributors.append((ng, f_us, f_uk, lr, round(weight, 6), round(contrib, 6)))
        if lr > 0:
            ctr_us += contrib
        else:
            ctr_uk += contrib

    total = ctr_us + ctr_uk
    if total == 0.0:
        return {
            "US_Alignment_Score": 0.0,
            "UK_Alignment_Score": 0.0,
            "Top_Contributors": []
        }
    return {
        "US_Alignment_Score": round(ctr_us / total, 4),
        "UK_Alignment_Score": round(ctr_uk / total, 4),
        "Top_Contributors": sorted(contributors, key=lambda x: x[5], reverse=True)[:15]
    }

# === CLI Entrypoint ===
def main():
    """CLI entry point: analyse ``--text`` and print the result as indented JSON.

    The n-gram cache is saved in a ``finally`` block, so frequencies fetched
    before an error or Ctrl-C are kept for the next run instead of being lost.
    """
    parser = argparse.ArgumentParser(
        description="Compute AmE/BrE alignment scores for a given text."
    )
    parser.add_argument('--text', required=True, help="The input text to analyze.")
    args = parser.parse_args()

    try:
        result = compute_alignment_score(args.text)
        print(json.dumps(result, indent=2))
    finally:
        save_cache()


if __name__ == '__main__':
    main()

